import math
import numpy as np
import gurobipy as grb

class Routing:
    """
    Desc: Traffic engineering for pod-level routing. Produces the routing
          weights / threshold matrix that an external link-utilization
          computation consumes (link utilization itself is not computed here).
    """

    # Kept for backward compatibility with class-level access; every instance
    # gets its own dict in __init__, so instances never share this one.
    _w_routing = {}

    def __init__(self, topology_pods):
        """
        Desc: Store the pod-level topology produced by the topology-engineering
              step and derive the number of pods from it.
        Inputs:
            - topology_pods(2-dimension list or array): Number of s_i egress
              links connected to ingress links of s_j.
        """
        self._topology_pods = topology_pods
        # traffic_engineering() needs the pod count; this class has no base
        # class that would provide it, so derive it from the topology.
        self._pods_num = len(topology_pods)
        self._w_routing = {}

    def traffic_engineering(self, actual_traffic, Sij, virtual_r):
        """
        Desc: Split each pod pair's traffic between the direct path and 2-hop
              detours through intermediate pods. Pure arithmetic, no solver.

              For every ordered pair (i, j), i != j (1-based pod ids):
                - traffic[i][j] <= Sij[i][j]: everything goes direct,
                  w_i_0_j = 1;
                - otherwise the excess (traffic - Sij) is spread over every
                  intermediate pod k (k != i, j) proportionally to
                  virtual_r[k] / (R - min(virtual_r[i], virtual_r[j])),
                  R = sum(virtual_r); the remaining fraction stays direct, so
                  the weights of a pair sum to 1.
        Inputs:
            - actual_traffic(2-dimension array-like): traffic[i-1][j-1] from
              pod i to pod j.
            - Sij(2-dimension array-like): per-pair threshold for the direct path.
            - virtual_r(1-dimension array-like): per-pod weights used to split
              the excess.
        Returns:
            - dict mapping 'w_{i}_{k}_{j}' to the fraction of i->j traffic sent
              via pod k (k = 0 is the direct path). Also stored in
              self._w_routing.
        """
        traffic = np.asarray(actual_traffic)
        thresholds = np.asarray(Sij)
        r = np.asarray(virtual_r, dtype=float)

        pods_num = self._pods_num
        pods = range(1, pods_num + 1)

        # k = 0 encodes the direct path; k >= 1 is an intermediate pod.
        routing = {
            f'w_{i}_{k}_{j}': 0
            for i in pods
            for k in range(pods_num + 1)
            for j in pods
            if i != k and i != j and j != k
        }

        R = r.sum()
        for i in pods:
            for j in pods:
                if i == j:
                    continue
                demand = traffic[i - 1][j - 1]
                threshold = thresholds[i - 1][j - 1]
                if demand > threshold:
                    # Only possible to skip with a negative threshold: with no
                    # positive demand there is nothing to split.
                    if demand > 0:
                        excess = max(0.0, demand - threshold)
                        denom = R - min(r[i - 1], r[j - 1])
                        detoured = 0
                        for k in pods:
                            if k != i and k != j:
                                w_ikj = excess * r[k - 1] / denom
                                routing[f'w_{i}_{k}_{j}'] = w_ikj / demand
                                detoured += w_ikj
                        # Whatever was not detoured stays on the direct path.
                        routing[f'w_{i}_{0}_{j}'] = (demand - detoured) / demand
                else:
                    routing[f'w_{i}_{0}_{j}'] = 1

        self._w_routing = routing
        return routing

    def routing(self, Sij, virtual_r, a_link_bandwidth, link_num):
        """
        Desc: Return the routing threshold matrix.

              With this scheme the threshold itself is the routing decision;
              later steps only make the link-utilization computation easier.
              (shizhenzhao): the threshold is now fixed (from scale_Sij), so
              the previous traffic / find_threshold step is no longer used.
        Inputs:
            - Sij, virtual_r, a_link_bandwidth, link_num: forwarded to scale_Sij.
        Returns:
            - The threshold matrix returned by scale_Sij.
        """
        return self.scale_Sij(Sij, virtual_r, a_link_bandwidth, link_num)

    def scale_Sij(self, Sij, virtual_r, a_link_bandwidth, link_num):
        """
        Desc: Scale Sij to the physical capacity and return the threshold.

              Solves  max alpha
                      s.t. alpha * (Sij[i][j] + r_i * r_j) <= C[i][j], i != j
              with C = topology_pods * a_link_bandwidth, then returns
              C[i][j] - alpha * r_i * r_j off the diagonal and
              alpha * Sij[i][i] on the diagonal.
        Inputs:
            - Sij(2-dimension array-like): per-pair demand bound (not modified).
            - virtual_r(1-dimension array-like): per-pod weights.
            - a_link_bandwidth(number): bandwidth of a single link.
            - link_num: unused; kept for interface compatibility.
        Returns:
            - numpy.ndarray of float with the same shape as Sij.
        Raises:
            - SystemExit(1) if the LP has no optimal solution.
        """
        # Copy as float: never mutate the caller's matrix, never truncate.
        sij = np.array(Sij, dtype=float)
        r = np.asarray(virtual_r, dtype=float)
        rows, cols = sij.shape
        # asarray first: a nested Python list times an int would be repeated,
        # not scaled.
        capacities = np.asarray(self._topology_pods, dtype=float) * a_link_bandwidth

        m = grb.Model('scale_Sij')
        m.Params.OutputFlag = 0
        alpha = m.addVar(vtype=grb.GRB.CONTINUOUS, lb=0, name='alpha')

        m.addConstrs(
            alpha * float(sij[i, j] + r[i] * r[j]) <= float(capacities[i, j])
            for i in range(rows)
            for j in range(cols)
            if i != j
        )

        m.setObjective(alpha, grb.GRB.MAXIMIZE)
        m.optimize()
        if m.status != grb.GRB.Status.OPTIMAL:
            print('scale_Sij No solution')
            # Non-zero status: a bare exit() would report success (status 0).
            raise SystemExit(1)
        alpha_opt = m.objVal

        # (shizhenzhao) Algebraically alpha * (C / alpha - r_i r_j); written
        # without the division so alpha_opt == 0 does not produce inf/nan.
        threshold = sij * alpha_opt
        for i in range(rows):
            for j in range(cols):
                if i != j:
                    threshold[i, j] = capacities[i, j] - alpha_opt * r[i] * r[j]
        return threshold

    def find_threshold(self, traffic, Sij, virtual_r):
        """
        Desc: Find the smallest scaling beta of Sij whose overflow fits the
              detour budget.

              Solves  min beta
                      s.t. u_ij >= traffic_ij - beta * Sij_ij,  u_ij >= 0
                           sum_j u_ij <= beta * r_i * (R - r_i)   (row i)
                           sum_i u_ij <= beta * r_j * (R - r_j)   (column j)
              with R = sum(virtual_r), and returns beta * Sij.
        Inputs:
            - traffic(2-dimension array-like): pod-to-pod traffic.
            - Sij(2-dimension array-like): base threshold matrix.
            - virtual_r(1-dimension array-like): per-pod weights.
        Returns:
            - numpy.ndarray of float, beta_opt * Sij.
        Raises:
            - SystemExit(1) if the LP has no optimal solution.
        """
        demand = np.asarray(traffic, dtype=float)
        sij = np.asarray(Sij, dtype=float)
        r = np.asarray(virtual_r, dtype=float)
        rows, cols = sij.shape
        R = r.sum()

        m = grb.Model('find_threshold')
        m.Params.OutputFlag = 0

        beta = m.addVar(vtype=grb.GRB.CONTINUOUS, lb=0, name='beta')
        inter_u = m.addVars(rows, cols, lb=0, vtype=grb.GRB.CONTINUOUS, name='u')

        m.addConstrs(
            inter_u[i, j] >= float(demand[i, j]) - beta * float(sij[i, j])
            for i in range(rows)
            for j in range(cols)
            if i != j
        )

        # Overflow leaving pod i must fit its detour budget.
        m.addConstrs(
            grb.quicksum(inter_u[i, j] for j in range(cols) if i != j)
            <= beta * float(r[i] * (R - r[i]))
            for i in range(rows)
        )

        # Overflow entering pod j must fit its detour budget (sum over rows).
        m.addConstrs(
            grb.quicksum(inter_u[i, j] for i in range(rows) if i != j)
            <= beta * float(r[j] * (R - r[j]))
            for j in range(cols)
        )

        m.setObjective(beta, grb.GRB.MINIMIZE)
        m.optimize()
        if m.status != grb.GRB.Status.OPTIMAL:
            print('find_threshold No solution')
            # Non-zero status: a bare exit() would report success (status 0).
            raise SystemExit(1)
        beta_opt = m.objVal

        return beta_opt * sij
